{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reimplementation of Block Coordinate Descent (BCD) Algorithm for Training DNNs (3-layer MLP) for MNIST in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch Version: 1.0.0\n",
      "Torchvision Version: 0.2.1\n",
      "GPU is available? True\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "5 runs of 50 epochs, seed = 10, 20, 30, 40, 50; \n",
    "validation accuracies: 0.9492, 0.9457, 0.9463, 0.9439, 0.9455\n",
    "\"\"\"\n",
    "from __future__ import print_function, division\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models, transforms, utils\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import os\n",
    "import copy\n",
    "\n",
    "print(\"PyTorch Version:\", torch.__version__)\n",
    "print(\"Torchvision Version:\", torchvision.__version__)\n",
    "print(\"GPU is available?\", torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read in MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `dtype` is not used in this cell; it is kept as a notebook-level name for later cells.\n",
    "dtype = torch.float\n",
    "# Run on the first GPU when CUDA is available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Convert to tensor and scale to [0, 1].\n",
    "# Normalize with mean 0 and std 1 computes (x - 0) / 1, i.e. it leaves the values unchanged.\n",
    "ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0,), (1,))])\n",
    "# Train and test splits of MNIST, stored under ../data (downloaded if missing).\n",
    "mnist_trainset = datasets.MNIST('../data', train=True, download=True, transform=ts)\n",
    "mnist_testset = datasets.MNIST(root='../data', train=False, download=True, transform=ts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data manipulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _dataset_to_tensors(dataset, n_features, n_classes):\n",
    "    \"\"\"Flatten every sample of `dataset` and one-hot encode its labels.\n",
    "\n",
    "    Relies on the notebook-level `device` defined in the data-loading cell.\n",
    "\n",
    "    Args:\n",
    "        dataset: indexable dataset whose items are (image tensor, label) pairs.\n",
    "        n_features: number of elements in one flattened image.\n",
    "        n_classes: number of label classes (height of the one-hot matrix).\n",
    "\n",
    "    Returns:\n",
    "        x: (n_features, n_samples) tensor on `device`, one sample per column.\n",
    "        y: (n_samples,) long tensor of labels on `device`.\n",
    "        y_one_hot: (n_classes, n_samples) one-hot labels on `device`.\n",
    "    \"\"\"\n",
    "    n_samples = len(dataset)\n",
    "    x = torch.empty((n_samples, n_features), device=device)\n",
    "    y = torch.empty(n_samples, dtype=torch.long)\n",
    "    for i in range(n_samples):\n",
    "        # Index the dataset only once per sample so each item is fetched a single time.\n",
    "        image, label = dataset[i]\n",
    "        x[i, :] = torch.reshape(image, (1, n_features))\n",
    "        y[i] = label\n",
    "    # The one-hot matrix is built on the CPU from the CPU labels, then moved to `device`.\n",
    "    one_hot = torch.zeros(n_samples, n_classes).scatter_(1, torch.reshape(y, (n_samples, 1)), 1)\n",
    "    return torch.t(x), y.to(device=device), torch.t(one_hot).to(device=device)\n",
    "\n",
    "\n",
    "# Image dimensions, read from the first training sample\n",
    "x_d0 = mnist_trainset[0][0].size()[0]\n",
    "x_d1 = mnist_trainset[0][0].size()[1]\n",
    "x_d2 = mnist_trainset[0][0].size()[2]\n",
    "K = 10  # number of classes\n",
    "\n",
    "# Manipulate train set\n",
    "N = x_d3 = len(mnist_trainset)\n",
    "x_train, y_train, y_one_hot = _dataset_to_tensors(mnist_trainset, x_d0*x_d1*x_d2, K)\n",
    "\n",
    "# Manipulate test set\n",
    "N_test = x_d3_test = len(mnist_testset)\n",
    "x_test, y_test, y_test_one_hot = _dataset_to_tensors(mnist_testset, x_d0*x_d1*x_d2, K)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main algorithm (Jinshan's Algorithm in Zeng et al. (2018))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the random number generators so the run is reproducible.\n",
    "seed = 40\n",
    "torch.manual_seed(seed)\n",
    "if torch.cuda.is_available():\n",
    "    # Explicitly seed the CUDA generators as well.\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "# Layer widths: flattened input, three hidden layers of equal width, K outputs.\n",
    "d0 = x_d0*x_d1*x_d2\n",
    "d1 = d2 = d3 = 1500\n",
    "d4 = K # Layers: input (d0) + 3 hidden (d1, d2, d3) + output (d4)\n",
    "\n",
    "# Weights are drawn from a normal distribution scaled by 0.01; biases start at 0.1.\n",
    "# The W1..W4 draws are kept in this order so the random stream is unchanged.\n",
    "W1 = 0.01*torch.randn(d1, d0, device=device)\n",
    "b1 = 0.1*torch.ones(d1, 1, device=device)\n",
    "W2 = 0.01*torch.randn(d2, d1, device=device)\n",
    "b2 = 0.1*torch.ones(d2, 1, device=device)\n",
    "W3 = 0.01*torch.randn(d3, d2, device=device)\n",
    "b3 = 0.1*torch.ones(d3, 1, device=device)\n",
    "W4 = 0.01*torch.randn(d4, d3, device=device)\n",
    "b4 = 0.1*torch.ones(d4, 1, device=device)\n",
    "\n",
    "# Initialise the auxiliary variables with one forward pass over the training set:\n",
    "# U_i = W_i V_{i-1} + b_i (with V_0 = x_train) and V_i = ReLU(U_i); the last layer is linear.\n",
    "U1 = torch.addmm(b1.repeat(1, N), W1, x_train)\n",
    "V1 = nn.ReLU()(U1)\n",
    "U2 = torch.addmm(b2.repeat(1, N), W2, V1)\n",
    "V2 = nn.ReLU()(U2)\n",
    "U3 = torch.addmm(b3.repeat(1, N), W3, V2)\n",
    "V3 = nn.ReLU()(U3) \n",
    "U4 = torch.addmm(b4.repeat(1, N), W4, V3)\n",
    "V4 = U4\n",
    "\n",
    "# Weights used by the block updates in the training cell (one shared value per family).\n",
    "gamma = 1\n",
    "gamma1 = gamma2 = gamma3 = gamma4 = gamma\n",
    "\n",
    "rho = 1\n",
    "rho1 = rho2 = rho3 = rho4 = rho\n",
    "\n",
    "\n",
    "alpha = 5\n",
    "alpha1 = alpha2 = alpha3 = alpha4 = alpha5 = alpha6 = alpha7 \\\n",
    "= alpha8 = alpha9 = alpha10 = alpha\n",
    "\n",
    "# Number of training iterations and per-iteration metric buffers.\n",
    "niter = 50\n",
    "loss1 = np.empty(niter)\n",
    "loss2 = np.empty(niter)\n",
    "accuracy_train = np.empty(niter)\n",
    "accuracy_test = np.empty(niter)\n",
    "time1 = np.empty(niter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define functions for updating blocks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def updateV(U1,U2,W,b,rho,gamma): \n",
    "    _, d = W.size()\n",
    "    I = torch.eye(d, device=device)\n",
    "    U1 = nn.ReLU()(U1)\n",
    "    _, col_U2 = U2.size()\n",
    "    Vstar = torch.mm(torch.inverse(rho*(torch.mm(torch.t(W),W))+gamma*I), rho*torch.mm(torch.t(W),U2-b.repeat(1,col_U2))+gamma*U1)\n",
    "    return Vstar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def updateWb(U, V, W, b, alpha, rho): \n",
    "    d,N = V.size()\n",
    "    I = torch.eye(d, device=device)\n",
    "    _, col_U = U.size()\n",
    "    Wstar = torch.mm(alpha*W+rho*torch.mm(U-b.repeat(1,col_U),torch.t(V)),torch.inverse(alpha*I+rho*(torch.mm(V,torch.t(V)))))\n",
    "    bstar = (alpha*b+rho*torch.sum(U-torch.mm(W,V), dim=1).reshape(b.size()))/(rho*N+alpha)\n",
    "    return Wstar, bstar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define the proximal operator of the ReLU activation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relu_prox(a, b, gamma, d, N):\n",
    "    val = torch.empty(d,N, device=device)\n",
    "    x = (a+gamma*b)/(1+gamma)\n",
    "    y = torch.min(b,torch.zeros(d,N, device=device))\n",
    "\n",
    "    val = torch.where(a+gamma*b < 0, y, torch.zeros(d,N, device=device))\n",
    "    val = torch.where(((a+gamma*b >= 0) & (b >=0)) | ((a*(gamma-np.sqrt(gamma*(gamma+1))) <= gamma*b) & (b < 0)), x, val)\n",
    "    val = torch.where((-a <= gamma*b) & (gamma*b <= a*(gamma-np.sqrt(gamma*(gamma+1)))), b, val)\n",
    "    return val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1 / 50 \n",
      " - time: 2.048199415206909 - sq_loss: 20286.423828125 - tot_loss: 20315.255833353847 - acc: 0.13381666666666667 - val_acc: 0.1318\n",
      "Epoch 2 / 50 \n",
      " - time: 1.8899121284484863 - sq_loss: 15321.1953125 - tot_loss: 15333.326430269517 - acc: 0.7147833333333333 - val_acc: 0.723\n",
      "Epoch 3 / 50 \n",
      " - time: 1.895658254623413 - sq_loss: 11774.9052734375 - tot_loss: 11782.7634959016 - acc: 0.8732666666666666 - val_acc: 0.8779\n",
      "Epoch 4 / 50 \n",
      " - time: 1.9003705978393555 - sq_loss: 9107.623046875 - tot_loss: 9114.166835594922 - acc: 0.9025666666666666 - val_acc: 0.9066\n",
      "Epoch 5 / 50 \n",
      " - time: 1.9022879600524902 - sq_loss: 7059.787109375 - tot_loss: 7065.728549709544 - acc: 0.9131 - val_acc: 0.9143\n",
      "Epoch 6 / 50 \n",
      " - time: 1.9122953414916992 - sq_loss: 5475.97607421875 - tot_loss: 5481.722796622664 - acc: 0.9197166666666666 - val_acc: 0.9194\n",
      "Epoch 7 / 50 \n",
      " - time: 2.066347360610962 - sq_loss: 4247.99658203125 - tot_loss: 4253.339897958562 - acc: 0.92385 - val_acc: 0.9238\n",
      "Epoch 8 / 50 \n",
      " - time: 2.0772886276245117 - sq_loss: 3295.244873046875 - tot_loss: 3300.5409899950027 - acc: 0.928 - val_acc: 0.9278\n",
      "Epoch 9 / 50 \n",
      " - time: 2.1313648223876953 - sq_loss: 2555.9765625 - tot_loss: 2560.868489332497 - acc: 0.9308166666666666 - val_acc: 0.9303\n",
      "Epoch 10 / 50 \n",
      " - time: 2.151984930038452 - sq_loss: 1982.4171142578125 - tot_loss: 1987.1164524555206 - acc: 0.9341166666666667 - val_acc: 0.9321\n",
      "Epoch 11 / 50 \n",
      " - time: 2.027393341064453 - sq_loss: 1537.486083984375 - tot_loss: 1542.280776258558 - acc: 0.93625 - val_acc: 0.9348\n",
      "Epoch 12 / 50 \n",
      " - time: 2.025909662246704 - sq_loss: 1192.377197265625 - tot_loss: 1196.7209557630122 - acc: 0.93835 - val_acc: 0.9368\n",
      "Epoch 13 / 50 \n",
      " - time: 1.9836673736572266 - sq_loss: 924.717041015625 - tot_loss: 929.4000244922936 - acc: 0.94 - val_acc: 0.9388\n",
      "Epoch 14 / 50 \n",
      " - time: 2.0575103759765625 - sq_loss: 717.1383056640625 - tot_loss: 722.9542279615998 - acc: 0.9420666666666667 - val_acc: 0.9407\n",
      "Epoch 15 / 50 \n",
      " - time: 1.9633827209472656 - sq_loss: 556.159423828125 - tot_loss: 560.5885963886976 - acc: 0.94365 - val_acc: 0.9413\n",
      "Epoch 16 / 50 \n",
      " - time: 2.0240511894226074 - sq_loss: 431.3221740722656 - tot_loss: 435.4561728797853 - acc: 0.94445 - val_acc: 0.943\n",
      "Epoch 17 / 50 \n",
      " - time: 1.983823537826538 - sq_loss: 334.513427734375 - tot_loss: 339.42329306155443 - acc: 0.9458 - val_acc: 0.943\n",
      "Epoch 18 / 50 \n",
      " - time: 1.9093897342681885 - sq_loss: 259.4398498535156 - tot_loss: 263.2498557753861 - acc: 0.9466666666666667 - val_acc: 0.9439\n",
      "Epoch 19 / 50 \n",
      " - time: 1.9240086078643799 - sq_loss: 201.22213745117188 - tot_loss: 205.00519214570522 - acc: 0.9478 - val_acc: 0.9438\n",
      "Epoch 20 / 50 \n",
      " - time: 1.9286179542541504 - sq_loss: 156.07469177246094 - tot_loss: 159.72269743308425 - acc: 0.9483333333333334 - val_acc: 0.9451\n",
      "Epoch 21 / 50 \n",
      " - time: 1.919344186782837 - sq_loss: 121.06202697753906 - tot_loss: 124.85546385124326 - acc: 0.9493166666666667 - val_acc: 0.9459\n",
      "Epoch 22 / 50 \n",
      " - time: 1.923945426940918 - sq_loss: 93.90869903564453 - tot_loss: 97.49829767085612 - acc: 0.9497833333333333 - val_acc: 0.9458\n",
      "Epoch 23 / 50 \n",
      " - time: 1.9155528545379639 - sq_loss: 72.8500747680664 - tot_loss: 76.88765461742878 - acc: 0.9504333333333334 - val_acc: 0.9464\n",
      "Epoch 24 / 50 \n",
      " - time: 1.9297609329223633 - sq_loss: 56.51704406738281 - tot_loss: 60.39970852248371 - acc: 0.9509833333333333 - val_acc: 0.947\n",
      "Epoch 25 / 50 \n",
      " - time: 1.9225642681121826 - sq_loss: 43.84898376464844 - tot_loss: 47.92639542743564 - acc: 0.9517 - val_acc: 0.9465\n",
      "Epoch 26 / 50 \n",
      " - time: 1.9234802722930908 - sq_loss: 34.02302932739258 - tot_loss: 37.62718463782221 - acc: 0.9521333333333334 - val_acc: 0.9469\n",
      "Epoch 27 / 50 \n",
      " - time: 1.9231202602386475 - sq_loss: 26.40119171142578 - tot_loss: 30.03031969256699 - acc: 0.9525 - val_acc: 0.9474\n",
      "Epoch 28 / 50 \n",
      " - time: 1.9119796752929688 - sq_loss: 20.48874855041504 - tot_loss: 23.98542689345777 - acc: 0.9528666666666666 - val_acc: 0.9475\n",
      "Epoch 29 / 50 \n",
      " - time: 1.9149060249328613 - sq_loss: 15.901873588562012 - tot_loss: 19.341799194458872 - acc: 0.9532 - val_acc: 0.9479\n",
      "Epoch 30 / 50 \n",
      " - time: 1.9200165271759033 - sq_loss: 12.343406677246094 - tot_loss: 15.94023716589436 - acc: 0.9533166666666667 - val_acc: 0.9471\n",
      "Epoch 31 / 50 \n",
      " - time: 1.9213383197784424 - sq_loss: 9.582348823547363 - tot_loss: 13.07723309006542 - acc: 0.954 - val_acc: 0.9474\n",
      "Epoch 32 / 50 \n",
      " - time: 1.9484777450561523 - sq_loss: 7.440096378326416 - tot_loss: 11.300622095819563 - acc: 0.9539333333333333 - val_acc: 0.9467\n",
      "Epoch 33 / 50 \n",
      " - time: 1.930436611175537 - sq_loss: 5.777571201324463 - tot_loss: 9.599534923676401 - acc: 0.95445 - val_acc: 0.9475\n",
      "Epoch 34 / 50 \n",
      " - time: 1.9378283023834229 - sq_loss: 4.487327575683594 - tot_loss: 8.042589909397066 - acc: 0.95415 - val_acc: 0.947\n",
      "Epoch 35 / 50 \n",
      " - time: 1.931406021118164 - sq_loss: 3.485767126083374 - tot_loss: 6.666785731445998 - acc: 0.95435 - val_acc: 0.9466\n",
      "Epoch 36 / 50 \n",
      " - time: 1.9388427734375 - sq_loss: 2.7083497047424316 - tot_loss: 6.498393258079886 - acc: 0.9543833333333334 - val_acc: 0.9463\n",
      "Epoch 37 / 50 \n",
      " - time: 1.9123201370239258 - sq_loss: 2.1048216819763184 - tot_loss: 6.314744739793241 - acc: 0.9546 - val_acc: 0.9465\n",
      "Epoch 38 / 50 \n",
      " - time: 1.939476728439331 - sq_loss: 1.6361708641052246 - tot_loss: 5.13390444801189 - acc: 0.9545666666666667 - val_acc: 0.9469\n",
      "Epoch 39 / 50 \n",
      " - time: 1.9159090518951416 - sq_loss: 1.272296667098999 - tot_loss: 4.714561613276601 - acc: 0.9545166666666667 - val_acc: 0.947\n",
      "Epoch 40 / 50 \n",
      " - time: 1.9141647815704346 - sq_loss: 0.9896255135536194 - tot_loss: 5.235138806281611 - acc: 0.95485 - val_acc: 0.9462\n",
      "Epoch 41 / 50 \n",
      " - time: 1.9211070537567139 - sq_loss: 0.7700644135475159 - tot_loss: 4.280116225942038 - acc: 0.9549166666666666 - val_acc: 0.9461\n",
      "Epoch 42 / 50 \n",
      " - time: 1.9144577980041504 - sq_loss: 0.59946608543396 - tot_loss: 3.938017624663189 - acc: 0.9548166666666666 - val_acc: 0.9461\n",
      "Epoch 43 / 50 \n",
      " - time: 1.9543743133544922 - sq_loss: 0.4668845534324646 - tot_loss: 3.4499458639184013 - acc: 0.9547666666666667 - val_acc: 0.946\n",
      "Epoch 44 / 50 \n",
      " - time: 2.736018657684326 - sq_loss: 0.3638189136981964 - tot_loss: 3.644937446573749 - acc: 0.9547666666666667 - val_acc: 0.9452\n",
      "Epoch 45 / 50 \n",
      " - time: 3.0279643535614014 - sq_loss: 0.2836720049381256 - tot_loss: 3.7486055734334514 - acc: 0.9548 - val_acc: 0.9456\n",
      "Epoch 46 / 50 \n",
      " - time: 3.106459140777588 - sq_loss: 0.22133870422840118 - tot_loss: 5.305122142075561 - acc: 0.95475 - val_acc: 0.9448\n",
      "Epoch 47 / 50 \n",
      " - time: 3.0520179271698 - sq_loss: 0.1728452891111374 - tot_loss: 4.659826255287044 - acc: 0.9543666666666667 - val_acc: 0.9455\n",
      "Epoch 48 / 50 \n",
      " - time: 3.043464422225952 - sq_loss: 0.1350969821214676 - tot_loss: 3.597071212134324 - acc: 0.9544833333333334 - val_acc: 0.9448\n",
      "Epoch 49 / 50 \n",
      " - time: 3.091150999069214 - sq_loss: 0.1057085320353508 - tot_loss: 3.2851635885890573 - acc: 0.9543166666666667 - val_acc: 0.945\n",
      "Epoch 50 / 50 \n",
      " - time: 2.9529201984405518 - sq_loss: 0.08281848579645157 - tot_loss: 3.289984852191992 - acc: 0.9541833333333334 - val_acc: 0.9439\n"
     ]
    }
   ],
   "source": [
    "# Iterations\n",
    "# Each pass updates the blocks in order from the output layer back to the first layer;\n",
    "# every update uses the values already refreshed earlier in the same pass, so the\n",
    "# statement order below must not be changed.\n",
    "print('Train on', N, 'samples, validate on', N_test, 'samples')\n",
    "for k in range(niter):\n",
    "    # The timer brackets the whole iteration body, including the evaluation passes below.\n",
    "    start = time.time()\n",
    "\n",
    "    # update V4: weighted average of the one-hot labels, U4 and the previous V4\n",
    "    V4 = (y_one_hot + gamma4*U4 + alpha1*V4)/(1 + gamma4 + alpha1)\n",
    "    \n",
    "    # update U4: weighted average of V4 and the linear output W4 V3 + b4\n",
    "    U4 = (gamma4*V4 + rho4*(torch.mm(W4,V3) + b4.repeat(1,N)))/(gamma4 + rho4)\n",
    "\n",
    "    # update W4 and b4\n",
    "    W4, b4 = updateWb(U4,V3,W4,b4,alpha2,rho4)\n",
    "    \n",
    "    # update V3\n",
    "    V3 = updateV(U3,U4,W4,b4,rho4,gamma3)\n",
    "    \n",
    "    # update U3: relu_prox of V3 against a weighted average of W3 V2 + b3 and the previous U3\n",
    "    U3 = relu_prox(V3,(rho3*torch.addmm(b3.repeat(1,N), W3, V2) + alpha3*U3)/(rho3 + alpha3),(rho3 + alpha3)/gamma3,d3,N)\n",
    "    \n",
    "    # update W3 and b3\n",
    "    W3, b3 = updateWb(U3,V2,W3,b3,alpha4,rho3)\n",
    "    \n",
    "    # update V2\n",
    "    V2 = updateV(U2,U3,W3,b3,rho3,gamma2)\n",
    "    \n",
    "    # update U2 (same form as the U3 update, one layer down)\n",
    "    U2 = relu_prox(V2,(rho2*torch.addmm(b2.repeat(1,N), W2, V1) + alpha5*U2)/(rho2 + alpha5),(rho2 + alpha5)/gamma2,d2,N)\n",
    "    \n",
    "    # update W2 and b2\n",
    "    W2, b2 = updateWb(U2,V1,W2,b2,alpha6,rho2)\n",
    "\n",
    "    # update V1\n",
    "    V1 = updateV(U1,U2,W2,b2,rho2,gamma1)\n",
    "    \n",
    "    # update U1 (the layer input is x_train)\n",
    "    U1 = relu_prox(V1,(rho1*torch.addmm(b1.repeat(1,N), W1, x_train) + alpha7*U1)/(rho1 + alpha7),(rho1 + alpha7)/gamma1,d1,N)\n",
    "\n",
    "    # update W1 and b1\n",
    "    W1, b1 = updateWb(U1,x_train,W1,b1,alpha8,rho1)\n",
    "\n",
    "    # Feed-forward pass on the training set with the current weights only\n",
    "    # (the auxiliary U/V variables are not used for prediction).\n",
    "    a1_train = nn.ReLU()(torch.addmm(b1.repeat(1, N), W1, x_train))\n",
    "    a2_train = nn.ReLU()(torch.addmm(b2.repeat(1, N), W2, a1_train))\n",
    "    a3_train = nn.ReLU()(torch.addmm(b3.repeat(1, N), W3, a2_train))\n",
    "    pred = torch.argmax(torch.addmm(b4.repeat(1, N), W4, a3_train), dim=0)\n",
    "\n",
    "    # Same feed-forward pass on the test set.\n",
    "    a1_test = nn.ReLU()(torch.addmm(b1.repeat(1, N_test), W1, x_test))\n",
    "    a2_test = nn.ReLU()(torch.addmm(b2.repeat(1, N_test), W2, a1_test))\n",
    "    a3_test = nn.ReLU()(torch.addmm(b3.repeat(1, N_test), W3, a2_test))\n",
    "    pred_test = torch.argmax(torch.addmm(b4.repeat(1, N_test), W4, a3_test), dim=0)\n",
    "    \n",
    "    # loss1: gamma4/2 * squared 2-norm distance between V4 and the one-hot labels.\n",
    "    # loss2: loss1 plus, for each layer i, rho_i/2 * squared distance between W_i V_{i-1} + b_i and U_i\n",
    "    # (with V_0 = x_train).\n",
    "    loss1[k] = gamma4/2*torch.pow(torch.dist(V4,y_one_hot,2),2).cpu().numpy()\n",
    "    loss2[k] = loss1[k] + rho1/2*torch.pow(torch.dist(torch.addmm(b1.repeat(1,N), W1, x_train),U1,2),2).cpu().numpy() \\\n",
    "    +rho2/2*torch.pow(torch.dist(torch.addmm(b2.repeat(1,N), W2, V1),U2,2),2).cpu().numpy() \\\n",
    "    +rho3/2*torch.pow(torch.dist(torch.addmm(b3.repeat(1,N), W3, V2),U3,2),2).cpu().numpy() \\\n",
    "    +rho4/2*torch.pow(torch.dist(torch.addmm(b4.repeat(1,N), W4, V3),U4,2),2).cpu().numpy()\n",
    "    \n",
    "    # compute training accuracy (fraction of predictions equal to the labels)\n",
    "    correct_train = pred == y_train\n",
    "    accuracy_train[k] = np.mean(correct_train.cpu().numpy())\n",
    "    \n",
    "    # compute validation accuracy (measured on the MNIST test split)\n",
    "    correct_test = pred_test == y_test\n",
    "    accuracy_test[k] = np.mean(correct_test.cpu().numpy())\n",
    "    \n",
    "    # compute the wall-clock time of this iteration (block updates plus evaluation)\n",
    "    stop = time.time()\n",
    "    duration = stop - start\n",
    "    time1[k] = duration\n",
    "    \n",
    "    # print results\n",
    "    print('Epoch', k + 1, '/', niter, '\\n', \n",
    "          '-', 'time:', time1[k], '-', 'sq_loss:', loss1[k], '-', 'tot_loss:', loss2[k], \n",
    "          '-', 'acc:', accuracy_train[k], '-', 'val_acc:', accuracy_test[k])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization of training results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5,1,'validation accuracy')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEICAYAAAC0+DhzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl8HfV57/HPV6slb5JseZWMDBiM2UwQxm2WkhCMSUigTdpAUnBbWjcp9JX0Jm1J0nuhadOQ9CZp6KW0LL5AQyAkQHB7nYBDSCgNi2VwsM0Sy8ZgeZX33bKk5/5xRvigxZK1Hemc7/v1Oq8z88xvznnGCD2a329mfooIzMzM0uVlOgEzMxt6XBzMzKwDFwczM+vAxcHMzDpwcTAzsw5cHMzMrAMXBzNA0r9K+p/93fYEc6iRFJIK+vuzzU6UfJ+DDXeS1gN/HBE/zXQufSGpBngDKIyI5sxmY7nOZw6W9fyXuNmJc3GwYU3SvwPTgP+QtF/SX6V1z1wn6S3gZ0nbH0jaImmPpKclnZn2OfdI+vtk+SJJDZI+L2mbpM2S/rCXbcdJ+g9JeyUtk/T3kp7p4bFNkbRY0k5J9ZL+JG3bHEl1yedulfStJD5C0ncl7ZC0O/nOiX36R7ac5OJgw1pEXAO8BXwkIkZFxDfSNv8WcAZwabL+Y2AGMAF4Ebj/OB89CRgLTAWuA26TVN6LtrcBB5I2C5JXTz0ANABTgI8D/yDp4mTbd4DvRMQY4BTgoSS+IMmlGhgHfBo4dALfaQa4OFh2uzkiDkTEIYCIWBQR+yLiCHAzcK6ksV3sexT4SkQcjYglwH7g9BNpKykf+BhwU0QcjIhXgHt7krikauA9wF9HxOGIWAHcBVyT9p2nShofEfsj4rm0+Djg1IhoiYjlEbG3J99pls7FwbLZhrYFSfmSbpG0VtJeYH2yaXwX++5oNyh8EBh1gm0rgYL0PNotH88UYGdE7EuLvUnq7ARSZyinAa8lXUeXJ/F/Bx4HHpS0SdI3JBX28DvN3ubiYNmgq0vu0uOfBK4APkiq26UmiWvg0qIRaAaq0mLVPdx3E1AhaXRabBqwESAi1kTE1aS6yL4O/FDSyOTs5W8jYhbwm8DlwLV9PA7LQS4Olg22Aid302Y0cATYAZQC/zDQSUVEC/AIcLOkUkkz6eEv6ojYAPwS+FoyyHwOqbOF+wEk/b6kyohoBXYnu7VIer+ks5Murb2kupla+vfILBe4OFg2+BrwN8nVOV/oos19pLplNgKvAM910a6/3UDqTGULqS6fB0gVqZ64mtQZzibgUVJjF0uTbfOB1ZL2kxqcvioiDpMa+P4hqcLwKvAL4Lv9ciSWU3wTnNkgkvR1YFJEnMhVS2aDzmcOZgNI0kxJ5yhlDqmuoUcznZdZd3znqNnAGk2qK2kKsA34JvBYRjMy6wF3K5mZWQfuVjIzsw6GbbfS+PHjo6amJtNpmJkNK8uXL98eEZXdtRu2xaGmpoa6urpMp2FmNqxIerMn7dytZGZmHbg4mJlZBy4OZmbWgYuDmZl14OJgZmYduDiYmVkHLg5mZtZBzhWHH720ke8+16PLfM3Mcla3xUFStaSnJL0qabWkzybxCklLJa1J3suTuCTdKqle0suS3pX2WQuS9mskLUiLny9pZbLPrZIGbHauJSs3c+8v1w/Ux5uZZYWenDk0A5+PiDOAucD1kmYBNwJPRsQM4MlkHeAyYEbyWgjcDqliAtwEXAjMAW5qKyhJm4Vp+83v+6F1rrqilIZdh/ADB83MutZtcYiIzRHxYrK8j9TsUlNJzcd7b9LsXuDKZPkK4L5IeQ4okzQZuBRYGhE7I2IXsBSYn2wbExHPRuo39n1pn9XvqspLOHS0hR0HmgbqK8zMhr0TGnOQVAOcBzwPTIyIzZAqIKQmOodU4diQtltDEjtevKGTeGffv1BSnaS6xsbGE0n9bdXl
pQBs2HmwV/ubmeWCHhcHSaOAh4HPRcTe4zXtJBa9iHcMRtwREbURUVtZ2e1DBTtVXZEUh12HerW/mVku6FFxkFRIqjDcHxGPJOGtSZcQyfu2JN4AVKftXkVqgvTjxas6iQ+IqvISwGcOZmbH05OrlQTcDbwaEd9K27QYaLviaAHHpj5cDFybXLU0F9iTdDs9DsyTVJ4MRM8DHk+27ZM0N/muaxnAaRRHFhcwbmQRDbtcHMzMutKT+RzeDVwDrJS0Iol9CbgFeEjSdcBbwO8m25YAHwLqgYPAHwJExE5JfwcsS9p9JSJ2JsufAe4BSoAfJ68BU1VeQoO7lczMutRtcYiIZ+h8XADg4k7aB3B9F5+1CFjUSbwOOKu7XPpLVUUpqzfuGayvMzMbdnLuDmlIXbG0cfchWlp9r4OZWWdyszhUlHC0Jdi693CmUzEzG5Jyszj4Xgczs+PKyeLQdjmrB6XNzDqXk8VhankJEmzw5axmZp3KyeJQXJDPxNEj2LDTZw5mZp3JyeIAqUFpnzmYmXUud4tDeSkNHpA2M+tUzhaHqvIStuw9TFNza6ZTMTMbcnK3OFSU0hqweY/HHczM2svZ4nDsXgcXBzOz9nK3OFQkj+72oLSZWQc5WxwmjRlBQZ58l7SZWSdytjgU5OcxuWyE75I2M+tEzhYHSI07uFvJzKyjnswEt0jSNkmr0mLfl7Qiea1vmwRIUo2kQ2nb/jVtn/MlrZRUL+nWZNY3JFVIWippTfJePhAH2pnq8lIPSJuZdaInZw73APPTAxHxiYiYHRGzSc0t/Uja5rVt2yLi02nx24GFwIzk1faZNwJPRsQM4MlkfVBUV5Swff8RDjW1DNZXmpkNC90Wh4h4GtjZ2bbkr//fAx443mdImgyMiYhnk5ni7gOuTDZfAdybLN+bFh9wVcnlrJ5P2szsnfo65vBeYGtErEmLTZf0kqRfSHpvEpsKNKS1aUhiABMjYjNA8j6hjzn1WNvlrB6UNjN7p27nkO7G1bzzrGEzMC0idkg6H/iRpDPpfA7qE56jU9JCUl1TTJs2rRfpvtPbN8L5zMHM7B16feYgqQD4HeD7bbGIOBIRO5Ll5cBa4DRSZwpVabtXAZuS5a1Jt1Nb99O2rr4zIu6IiNqIqK2srOxt6m+rHF1McUGe73UwM2unL91KHwRei4i3u4skVUrKT5ZPJjXwvC7pLtonaW4yTnEt8Fiy22JgQbK8IC0+4CRRVV7iK5bMzNrpyaWsDwDPAqdLapB0XbLpKjoORL8PeFnSr4AfAp+OiLbB7M8AdwH1pM4ofpzEbwEukbQGuCRZHzRVvtfBzKyDbsccIuLqLuJ/0EnsYVKXtnbWvg44q5P4DuDi7vIYKNUVJazYsDtTX29mNiTl9B3SkBqU3nPoKHsPH810KmZmQ4aLQ0Xbo7vdtWRm1sbFwfM6mJl1kPPFoaq87UY4nzmYmbXJ+eJQVlrIqOIC3yVtZpYm54vDsXsdfOZgZtYm54sDpAalfa+DmdkxLg4cm9ch9cBYMzNzcSA1KH3oaAs7DjRlOhUzsyHBxYFj9zp4UNrMLMXFgWPzOnhQ2swsxcUBz+tgZtaeiwMwsriAipFFvkvazCzh4pCoKi/xXdJmZgkXh0R1eakHpM3MEi4OiaqKEjbuOkRLq+91MDPryUxwiyRtk7QqLXazpI2SViSvD6Vt+6KkekmvS7o0LT4/idVLujEtPl3S85LWSPq+pKL+PMCemjFhNE0trazfcSATX29mNqT05MzhHmB+J/FvR8Ts5LUEQNIsUtOHnpns8y+S8pN5pW8DLgNmAVcnbQG+nnzWDGAXcF37LxoMMyeNBuC1zfsy8fVmZkNKt8UhIp4GdnbXLnEF8GBEHImIN0jNFz0nedVHxLqIaAIeBK6QJOADpOabBrgXuPIEj6FfnDphFPl54rUtezPx9WZmQ0pfxhxukPRy0u1UnsSmAhvS2jQksa7i44DdEdHcLt4pSQsl1Umqa2xs
7EPqHY0ozOfk8SN51WcOZma9Lg63A6cAs4HNwDeTuDppG72Idyoi7oiI2oioraysPLGMe2Dm5DE+czAzo5fFISK2RkRLRLQCd5LqNoLUX/7VaU2rgE3HiW8HyiQVtItnxMxJo2nYdYi9h49mKgUzsyGhV8VB0uS01d8G2q5kWgxcJalY0nRgBvACsAyYkVyZVERq0HpxpJ6R/RTw8WT/BcBjvcmpP5wxOTUo/est7loys9xW0F0DSQ8AFwHjJTUANwEXSZpNqgtoPfCnABGxWtJDwCtAM3B9RLQkn3MD8DiQDyyKiNXJV/w18KCkvwdeAu7ut6M7QTMnjQHg1S37qK2pyFQaZmYZ121xiIirOwl3+Qs8Ir4KfLWT+BJgSSfxdRzrlsqoyWNHMGZEAa9t9riDmeU23yGdRlIyKO1uJTPLbS4O7ZwxaTSvb9lHqx+jYWY5zMWhnZmTx7D/SDMbd/shfGaWu1wc2ml7jMarHncwsxzm4tDOaRNHI+FxBzPLaS4O7YwsLuCkilLfKW1mOc3FoRMzJ43xM5bMLKe5OHRi5uTRrN9xgINNzd03NjPLQi4OnZg5aQwR8Out+zOdiplZRrg4dKLtGUu+U9rMcpWLQyeqy0spLcr3FUtmlrNcHDqRlydOnzTa9zqYWc5ycejCzEmpZyylnipuZpZbXBy6cMbk0ew5dJQtew9nOhUzs0Hn4tCFtrkdXvP9DmaWg7otDpIWSdomaVVa7B8lvSbpZUmPSipL4jWSDklakbz+NW2f8yWtlFQv6VZJSuIVkpZKWpO8lw/EgZ6o09ueseQ7pc0sB/XkzOEeYH672FLgrIg4B/g18MW0bWsjYnby+nRa/HZgIampQ2ekfeaNwJMRMQN4MlnPuLElhUwtK/GZg5nlpG6LQ0Q8DexsF3siItpuH34OqDreZyRzTo+JiGeTeaPvA65MNl8B3Jss35sWz7iZk0b7GUtmlpP6Y8zhj4Afp61Pl/SSpF9Iem8Smwo0pLVpSGIAEyNiM0DyPqGrL5K0UFKdpLrGxsZ+SP34Zk4ezdrGAxxpbhnw7zIzG0r6VBwkfRloBu5PQpuBaRFxHvA/gO9JGgOok91P+BrRiLgjImojoraysrK3affYzEljaGkN6rf5MRpmllt6XRwkLQAuBz6VdBUREUciYkeyvBxYC5xG6kwhveupCtiULG9Nup3aup+29Tan/nbsMRoedzCz3NKr4iBpPvDXwEcj4mBavFJSfrJ8MqmB53VJd9E+SXOTq5SuBR5LdlsMLEiWF6TFM65m3EiKCvI87mBmOaeguwaSHgAuAsZLagBuInV1UjGwNLki9bnkyqT3AV+R1Ay0AJ+OiLbB7M+QuvKphNQYRds4xS3AQ5KuA94CfrdfjqwfFOTncdrEUX7GkpnlnG6LQ0Rc3Un47i7aPgw83MW2OuCsTuI7gIu7yyNTzpg0hqdeH/jBbzOzocR3SHdj5uQxbN9/hMZ9RzKdipnZoHFx6MasyanHaKzauCfDmZiZDR4Xh27Mri6jIE+8sH5n943NzLKEi0M3SoryObtqLMvecHEws9zh4tADc2oq+FXDbg4f9Z3SZpYbXBx64IKaCo62BCs27M50KmZmg8LFoQdqa1JPEXfXkpnlCheHHigrLWLmpNEelDaznOHi0EMX1FTw4pu7aG5pzXQqZmYDzsWhhy6YXsGBphZe2eznLJlZ9nNx6KE5NRUAvOBxBzPLAS4OPTRp7AimVZSyzOMOZpYDXBxOwAU1FSxbv4tk+gozs6zl4nAC5kwvZ+eBJtY2emY4M8tuLg4n4IK3xx12ZTgTM7OB1aPiIGmRpG2SVqXFKiQtlbQmeS9P4pJ0q6R6SS9LelfaPguS9muSaUbb4udLWpnsc2syW9yQM338SMaPKvK4g5llvZ6eOdwDzG8XuxF4MiJmAE8m6wCXkZoedAawELgdUsWE1CxyFwJzgJvaCkrSZmHafu2/a0iQxJzpFb5iycyyXo+KQ0Q8
DbT/jXgFcG+yfC9wZVr8vkh5DiiTNBm4FFgaETsjYhewFJifbBsTEc9GaqT3vrTPGnIuqKlg4+5DbNx9KNOpmJkNmL6MOUyMiM0AyfuEJD4V2JDWriGJHS/e0El8SGobd/Bzlswsmw3EgHRn4wXRi3jHD5YWSqqTVNfYmJl5nc+YPIbRxQV+zpKZZbW+FIetSZcQyfu2JN4AVKe1qwI2dROv6iTeQUTcERG1EVFbWVnZh9R7Lz9PnF9T7nEHM8tqfSkOi4G2K44WAI+lxa9NrlqaC+xJup0eB+ZJKk8GoucBjyfb9kmam1yldG3aZw1JF9RUUL9tPzsPNGU6FTOzAdHTS1kfAJ4FTpfUIOk64BbgEklrgEuSdYAlwDqgHrgT+DOAiNgJ/B2wLHl9JYkBfAa4K9lnLfDjvh/awJkzPRl3cNeSmWWpgp40ioiru9h0cSdtA7i+i89ZBCzqJF4HnNWTXIaCc6rGUlSQx7I3dnLpmZMynY6ZWb/zHdK9UFyQz+zqMg9Km1nWcnHopTk1FazetJcDR5oznYqZWb9zceilC6ZX0NIavPiWn7NkZtnHxaGXzj+pnMJ88V9rtmc6FTOzfufi0EujiguYe/I4nli9xfM7mFnWcXHog3lnTmL9joPUb/P8DmaWXVwc+uCSMyYC8MQrWzOciZlZ/3Jx6INJY0dwbnUZT6zekulUzMz6lYtDH82bNZFfNexhy57DmU7FzKzfuDj00bxZqa6lpa+6a8nMsoeLQx+dOmEU08ePdNeSmWUVF4c+ksS8WRN5bt0O9h4+mul0zMz6hYtDP5h35kSOtgQ/fz0zExCZmfU3F4d+MLu6nPGjity1ZGZZw8WhH+TniQ+eMZGfv97IkeaWTKdjZtZnLg79ZN6ZE9l/pJln1+7IdCpmZn3W6+Ig6XRJK9JeeyV9TtLNkjamxT+Uts8XJdVLel3SpWnx+UmsXtKNfT2oTPjNU8ZTWpTPUt8tbWZZoNfFISJej4jZETEbOB84CDyabP5227aIWAIgaRZwFXAmMB/4F0n5kvKB24DLgFnA1UnbYWVEYT4XnV7J0le20trqB/GZ2fDWX91KFwNrI+LN47S5AngwIo5ExBuk5ouek7zqI2JdRDQBDyZth51LZk1k274j/Kphd6ZTMTPrk/4qDlcBD6St3yDpZUmLJJUnsanAhrQ2DUmsq3gHkhZKqpNU19g49C4b/cDpE8nPkx/EZ2bDXp+Lg6Qi4KPAD5LQ7cApwGxgM/DNtqad7B7HiXcMRtwREbURUVtZWdmnvAfC2NJC5p5c4XEHMxv2+uPM4TLgxYjYChARWyOiJSJagTtJdRtB6oygOm2/KmDTceLD0rxZk6jftp+1jZ7jwcyGr/4oDleT1qUkaXLatt8GViXLi4GrJBVLmg7MAF4AlgEzJE1PzkKuStoOS5ckD+J7YrXPHsxs+OpTcZBUClwCPJIW/oaklZJeBt4P/AVARKwGHgJeAX4CXJ+cYTQDNwCPA68CDyVth6UpZSWcW13Gj17a6OlDzWzYKujLzhFxEBjXLnbNcdp/FfhqJ/ElwJK+5DKUXH1BNTc+spIX39rF+SdVZDodM7MT5jukB8BHzp3CqOIC7n/urUynYmbWKy4OA2BkcQFXnjeF/1y5md0HmzKdjpnZCXNxGCCfnHMSTc2tPPzixkynYmZ2wlwcBsisKWOYXV3G955/0wPTZjbsuDgMoE9eOI21jQd44Y2dmU7FzOyEuDgMoI+cM4XRIwr43gsemDaz4cXFYQCVFOXzsXdV8eOVW9h5wAPTZjZ8uDgMsE9eOI2mllYeXt6Q6VTMzHrMxWGAnTZxNLUnlfPAC295YNrMhg0Xh0HwyQunsW77AZ5d5ylEzWx4cHEYBB86ezJjSwr53vMemDaz4cHFYRCMKEwNTD++egvb9x/JdDpmZt1ycRgkn7ywmqMtwQ/qPDBtZkOfi8MgOXXCaC6cXsF3n3uTpubWTKdjZnZcLg6D6DMXncLG
3Yf4ft2G7hubmWVQf8whvT6Z3GeFpLokViFpqaQ1yXt5EpekWyXVS3pZ0rvSPmdB0n6NpAV9zWso+q3TKrmgppz/87M1HD7akul0zMy61F9nDu+PiNkRUZus3wg8GREzgCeTdUjNNz0jeS0EbodUMQFuAi4kNef0TW0FJZtI4gvzTmfr3iP8+7NvZjodM7MuDVS30hXAvcnyvcCVafH7IuU5oCyZc/pSYGlE7IyIXcBSYP4A5ZZRF548jvfOGM/tv1jL/iPNmU7HzKxT/VEcAnhC0nJJC5PYxIjYDJC8T0jiU4H0DveGJNZV/B0kLZRUJ6musbGxH1LPjC/MO52dB5pY9MwbmU7FzKxT/VEc3h0R7yLVZXS9pPcdp606icVx4u8MRNwREbURUVtZWdm7bIeAc6vLmDdrInc+vc4zxZnZkNTn4hARm5L3bcCjpMYMtibdRSTv25LmDUB12u5VwKbjxLPW5+edzv6mZv7t6XWZTsXMrIM+FQdJIyWNblsG5gGrgMVA2xVHC4DHkuXFwLXJVUtzgT1Jt9PjwDxJ5clA9LwklrVOnzSaj547hXv+ez3b9h3OdDpmZu/Q1zOHicAzkn4FvAD8v4j4CXALcImkNcAlyTrAEmAdUA/cCfwZQETsBP4OWJa8vpLEstpffPA0mlpa+Zen1mY6FTOzd9BwfYx0bW1t1NXVZTqNPrvx4Zd55MWNPPWXFzG1rCTT6ZhZlpO0PO22gy75DukM+/OLZwBw60/XZDgTM7NjXBwybGpZCZ+aO40fvtjAyoY9mU7HzAxwcRgSPnfxaYwbWcQXfvArjjT7sRpmlnkuDkPA2NJCbvnY2by+dR///GR9ptMxM3NxGCo+MHMiHz+/itt/sZaXG3ZnOh0zy3EuDkPI/7x8FpWjivn8Q+5eMrPMcnEYQsaWFPK1j53Nmm37+SdfvWRmGeTiMMS8//QJfKK2mn/7xVpWbHD3kpllhovDEPTly89g4pgRfP6hFZ4UyMwywsVhCBozopCvf+wc1jYe4Ns//XWm0zGzHOTiMES977RKrp5TzZ1Pr+OX9dsznY6Z5RgXhyHsyx+exakTRvGn311O/bZ9mU7HzHKIi8MQNqq4gEV/cAHFBfn8wf9dRuO+I5lOycxyhIvDEFdVXsrdC2rZvv8If3xfHYeaPEBtZgPPxWEYOLe6jFuvOo+XG3bzF99fQWvr8HzMupkNH70uDpKqJT0l6VVJqyV9NonfLGmjpBXJ60Np+3xRUr2k1yVdmhafn8TqJd3Yt0PKTvPOnMTffHgWP1m9ha/9+NVMp2NmWa6gD/s2A5+PiBeTqUKXS1qabPt2RPzv9MaSZgFXAWcCU4CfSjot2XwbqRnjGoBlkhZHxCt9yC0r/dG7a3hrxwHu/K83mDZuJNfMPSnTKZlZlup1cUjmft6cLO+T9Cow9Ti7XAE8GBFHgDck1QNzkm31EbEOQNKDSVsXh3Yk8b8+ciYNuw5x02OrqBxVxPyzJmc6LTPLQv0y5iCpBjgPeD4J3SDpZUmLJJUnsanAhrTdGpJYV3HrRH6euPXq8zinqow/u/9F7n/+zUynZGZZqM/FQdIo4GHgcxGxF7gdOAWYTerM4pttTTvZPY4T7+y7Fkqqk1TX2NjY19SHrZHFBXzvTy7kfadV8uVHV/Htpb9muM4FbmZDU5+Kg6RCUoXh/oh4BCAitkZES0S0AndyrOuoAahO270K2HSceAcRcUdE1EZEbWVlZV9SH/ZKiwq489paPn5+Fd95cg1fenQVzS2tmU7LzLJEX65WEnA38GpEfCstnt4J/tvAqmR5MXCVpGJJ04EZwAvAMmCGpOmSikgNWi/ubV65pDA/j3/8+Dlc//5TeOCFt/jM/S/6QX1m1i/6crXSu4FrgJWSViSxLwFXS5pNqmtoPfCnABGxWtJDpAaam4HrI6IFQNINwONAPrAoIlb3Ia+cIom/vHQmE0aP4Ob/WM2n7nqeuxfUUlZalOnUzGwY03Dt
q66trY26urpMpzGkLFm5mc89uIIJY4r51u/NZs70ikynZGZDjKTlEVHbXTvfIZ1FPnT2ZB5YOJc8iU/c8Sxf/8lrNDV7HMLMTpyLQ5Y5/6Rylnz2vXyitprbf76WK2/7b3691U90NbMT4+KQhUYVF3DLx87hjmvOZ+vew1z+z8+w6Jk3/EwmM+sxF4csNu/MSfzkc+/jvaeO5yv/+QqfvOs5Vm3ck+m0zGwYcHHIcpWji7lrQS23/M7ZvL5lH5f/8zN89sGXeGvHwUynZmZDmK9WyiF7Dx/l336xlrufeYOW1uBTF57En3/gVMaNKs50amY2SHp6tZKLQw7auvcw//TTX/P9ZRsoLSpg4ftO5pq5J1E+0vdGmGU7FwfrVv22fXzjJ6/zxCtbKSrI46PnTuHa3ziJc6rKMp2amQ0QFwfrsde37OO+Z9fz6EsbOdjUwrnVZVw79yQ+fM5kRhTmZzo9M+tHLg52wvYePsojyxu477k3Wdd4gIqRRcw/axKXnTWJuSePozDf1y+YDXcuDtZrEcF/1+/ggRfe4mevbePQ0RbKSgv54BkTueysSbxnxniKC3xGYTYc9bQ49OXBe5alJPGeGeN5z4zxHGpq4ek1jfxk1RYeX72FHy5vYFRxAb95yjh+I3mdNmE0eXmdTcthZsOVi4MdV0lRPpeeOYlLz5xEU3Mrv1y7ncdXb+GZ+u088cpWACpGFjH35Ap+45TxXFBTzqmVoyhwF5TZsObiYD1WVJDHRadP4KLTJwDQsOsgz67dwbPrdvDs2h0sWbkFgOKCPGZOHsPZU8dw1pSxnDV1LDMmjnJXlNkw4jEH6xcRwZs7DrJiw25WbdzDqk17WL1xL/uONAOpua+ry0uYPn4kJ1eOSt5HMn38SCaOHuFuKbNB4jEHG1SSqBk/kprxI7nyvKkAtLYGG3YdZNXGvby2ZS/rth9gXeMBnl23g8NHjz1KvDBfTB5bwpSyEUwtK2Vq2QimlJUwccwIxo8qZvzoIsaNLKaowF1VZoNlyBQHSfOB75CaDe6uiLglwylZH+XliZPGjeSkcSP58DnHZo9tbQ227D3MG9sP8Mb2A2zcfYhNuw+xcdchfrl2O1v3HqazB8iWlRZSOaqY8pFFlJUUUlbsxRmlAAAGd0lEQVRaSFlpEWOT5bElhYwqLmD0iAJGFRcyakQBo4pTr3yfmZidkCFRHCTlA7cBlwANwDJJiyPilcxmZgMhL09MKSthSlkJ7z51fIftR1ta2bLnMI37j7B93xG272+icd8Rtu9PvXYdbOKtnQd5ueEouw81veMspCtFBXmUFuVTWphPSVE+pUUFlBTlM6Iwn+KCvOSVT3FharmoII/i/NR7Yfp7fh4F+aIgP4/CvNR7Qb4ozMsjP0+pbXmiIG09TyI/T+RL5OWRtpzalifeXs6XkDgWT9ZTU7abDZ4hURyAOUB9RKwDkPQgcAWp+aYtxxTm51FdUUp1RWmP2h8+2sKeQ0fZc+go+480s/9w89vv+5L3g0ebOdTUwsGmluS9mYNNLew9dJQjza0cOdqSem9u4fDRVppaWofULHoSiFSREKmiwduxpIikbSe9fbJM2vZjtUZvL6fH1S7etm/7nNLfO+TM8Qtad59/otTlSrfhLr+7qzHZaLfQk5FbtVvoyXF3teXuBRcwbVzP/v/oraFSHKYCG9LWG4AL2zeStBBYCDBt2rTBycyGvBGFqTOAiWNG9OvnRgTNrUFTc6pQNLW0crSlleaWoLm1laMtQXNLcLQ1FWtpTcWbW4OWltS+za2ttLQGrRG0tKa61Foi1TYiaA3e3t7WJggiUm1bg7e3RRzbFqTiJMttn5Xepu0Y4u3jabctibWtHdsn1Y607emfkbR+54b2/3Y9+Lft7HO7+qzuykb67t3+Qu/xhkQ3habtl/vxcjx2nO887uN9dxxn42CMvw2V4tDZv2uHf5mIuAO4A1JXKw10UpbbJFGYLwrz8xjpp5pbjhkql380
ANVp61XApgzlYmaW84ZKcVgGzJA0XVIRcBWwOMM5mZnlrCHRrRQRzZJuAB4ndSnroohYneG0zMxy1pAoDgARsQRYkuk8zMxs6HQrmZnZEOLiYGZmHbg4mJlZBy4OZmbWwbB9ZLekRuDNbpqNB7YPQjpDjY87t/i4c0tfj/ukiKjsrtGwLQ49IamuJ88tzzY+7tzi484tg3Xc7lYyM7MOXBzMzKyDbC8Od2Q6gQzxcecWH3duGZTjzuoxBzMz651sP3MwM7NecHEwM7MOsrI4SJov6XVJ9ZJuzHQ+A0nSIknbJK1Ki1VIWippTfJenskc+5ukaklPSXpV0mpJn03iWX3cAJJGSHpB0q+SY//bJD5d0vPJsX8/efR9VpGUL+klSf+ZrGf9MQNIWi9ppaQVkuqS2ID/rGddcZCUD9wGXAbMAq6WNCuzWQ2oe4D57WI3Ak9GxAzgyWQ9mzQDn4+IM4C5wPXJf+NsP26AI8AHIuJcYDYwX9Jc4OvAt5Nj3wVcl8EcB8pngVfT1nPhmNu8PyJmp93fMOA/61lXHIA5QH1ErIuIJuBB4IoM5zRgIuJpYGe78BXAvcnyvcCVg5rUAIuIzRHxYrK8j9QvjKlk+XEDRMr+ZLUweQXwAeCHSTzrjl1SFfBh4K5kXWT5MXdjwH/Ws7E4TAU2pK03JLFcMjEiNkPqFykwIcP5DBhJNcB5wPPkyHEn3SsrgG3AUmAtsDsimpMm2fgz/0/AXwGtyfo4sv+Y2wTwhKTlkhYmsQH/WR8yk/30I3US8/W6WUjSKOBh4HMRsTf1x2T2i4gWYLakMuBR4IzOmg1uVgNH0uXAtohYLumitnAnTbPmmNt5d0RskjQBWCrptcH40mw8c2gAqtPWq4BNGcolU7ZKmgyQvG/LcD79TlIhqcJwf0Q8koSz/rjTRcRu4Oekxl3KJLX9sZdtP/PvBj4qaT2pbuIPkDqTyOZjfltEbEret5H6Y2AOg/Czno3FYRkwI7mSoQi4Clic4ZwG22JgQbK8AHgsg7n0u6S/+W7g1Yj4VtqmrD5uAEmVyRkDkkqAD5Iac3kK+HjSLKuOPSK+GBFVEVFD6v/nn0XEp8jiY24jaaSk0W3LwDxgFYPws56Vd0hL+hCpvyzygUUR8dUMpzRgJD0AXETqMb5bgZuAHwEPAdOAt4DfjYj2g9bDlqT3AP8FrORYH/SXSI07ZO1xA0g6h9QAZD6pP+4eioivSDqZ1F/VFcBLwO9HxJHMZTowkm6lL0TE5blwzMkxPpqsFgDfi4ivShrHAP+sZ2VxMDOzvsnGbiUzM+sjFwczM+vAxcHMzDpwcTAzsw5cHMzMrAMXBzMz68DFwczMOvj/isxs1jIZtTMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH9NJREFUeJzt3XmUZGd53/Hvc29t3T2bZpE0+0hCECQkJI4QOBJBYAwSJhL4OFgKigEDCkkwEMCJIGYHO8GxMTlWYrMFYhZZZp1DhNk5BsKikYUBjZA1iJGmNZJmNHt37fc++ePemqmpqZ6uGbq7+t76fc6pU3fre5+3uvtXb79VXa+5OyIiki/BsAsQEZG5p3AXEckhhbuISA4p3EVEckjhLiKSQwp3EZEcUrjLgjCzq8xssmv9bjO7apBjT+Naf2lmbz3drxfJg8KwC5DR5O4XzsV5zOxlwCvd/cquc796Ls4tkmXquYtkhJmpMyYDU7jLwMzsZjP7TM+2D5jZ/0iXX25m95jZETO738z+7UnOtdPMnpMuj5nZx8zsgJltB57a57q/SM+73cxelG5/IvCXwK+Z2ZSZHUy3f8zM3tP19a8ysx1mtt/MtprZuq59bmavNrP70uvfYmY2Q82Xm9n3zeygmT1sZn9hZqWu/Rea2dfS6zxqZm9Jt4dm9pauNtxpZhvNbEt6/ULXOb5tZq9Ml19mZt8zs/eb2X7gHWZ2npl908z2mdljZvZJM1vR9fUbzexzZrY3PeYvzKyc1nRR13FnmlnNzNbM9D2SbFO4y6n4NPB8M1sGSWgBLwY+le7fA7wAWAa8HHi/mT1lgPO+HTgvvT0PeGnP/l8AzwCWA+8EPmFma939HuDVwPfdfYm7r+j5Oszs2cAfp3WuBR4Abu057AUkTyhPTo973gx1RsB/BFYDvwb8OvDv0+ssBb4O/B2wDngc8I30694A3AA8n+Sx+T2gerIHpMvTgPuBM4H3Apa2Zx3wRGAj8I60hhD4UtrGLcB64FZ3b6RtvrHrvDcAX3f3vQPWIVnj7rrpNvAN+C7wu+nybwC/OMmxXwBely5fBUx27dsJPCddvh+4umvfTd3H9jnvj4Hr0uWXAd/t2f8x4D3p8keA93XtWwK0gC3pugNXdu2/Dbh5wMfi9cDn0+UbgLtmOO7eTr0927ek1y90bfs2yWsInbY9OEsNL+xcl+QJZ2/3+bqOexqwCwjS9W3Ai4f986Tb/N3Uc5dT9SmSIAP41xzrtWNm15jZD9IhgIMkPdXVA5xzHUnwdDzQvdPMftfMfpwOhxwEnjTgeTvnPno+d58C9pH0ajse6VqukjwBnMDMHm9mXzKzR8zsMPBHXXVsJPkLo5+T7ZtN9+PSGU651cweSmv4RE8ND7h7u/ck7v5DYBp4ppn9M5K/LLaeZk2SAQp3OVV/C1xlZhuAF5GGu5mVgc8C/x04y5MhkttJhhFm8zBJMHVs6iyY2WbgQ8BrgFXpeX/Wdd7ZPtZ0N7C563wTwCrgoQHq6vW/gJ8D57v7MuAtXXXsIhlW6memfdPp/XjXtrN7jult3x+n2y5Oa7ixp4ZNJ3nh9ePp8f8G+Iy712c4TnJA4S6nxJMx2m8D/xv4pSfj3gAloEwyLNA2s2uA5w542tuAN5vZGemTxu937ZsgCbO9kLxoS9Jz73gU2ND9wmaPTwEvN7NL0iegPwJ+6O47B6yt21LgMDCV9n7/Xde+LwFnm9nr0xcwl5rZ09J9HwbebWbnW+JiM1uVPpYPATemL7r+HjM/QXTXMAUcNLP1wB907fsRyRPlfzWzCTOrmNkVXfv/muQJ+Ubg/5xG+yVDFO5yOj4FPIeuIRl3PwK8liSoD5AM2Qz6Z/87SYZOfgl8lSSEOufdDvwp8H2SIL8I+F7X134TuBt4xMwe6z2xu38DeCvJXxUPk4Tn9QPW1etNJO06QvLXxN90XecIyWsQ
/5JkmOc+4Fnp7j8jeVy+SvLk8BFgLN33KpKA3gdcCPy/WWp4J/AU4BDwf4HPddUQpdd/HPAgMAn8Ttf+SeAfSJ4sv3MK7ZYMMndN1iEyKszso8Bud//DYdci80v/FCEyIsxsC/BbwKXDrUQWgoZlREaAmb2b5IXoP3H3Xw67Hpl/GpYREckh9dxFRHJoaGPuq1ev9i1btgzr8iIimXTnnXc+5u6zfibQ0MJ9y5YtbNu2bViXFxHJJDN7YPajNCwjIpJLCncRkRxSuIuI5JDCXUQkhxTuIiI5pHAXEckhhbuISA7pg8NE5oi7Ezu045g4hsidKHbi2GnH6fRnQOyOe/K5u54uN6OYZju5tTrLUXx0pg4DOvN2G8m525HTipLj25En13UIzQgCIwwgMCMMjMCMOK0vqdOJ46SWo/Uf1xhwkuOTY48t9+qejeVojXZsnwNR7Mdu6fmiOLlG57FIH8Tuk6XtBsMI7Nh5u6/VdXif78mJ22bS+fowfczCwCgERhgEhEFyPT/ue5fUn7QzqS9IiiVIaw+CZF/n3IElyxetX87mVRODF3caFO6y6ESxU2tF1FsRtWZEox3RitIwi2Oi2I8GWifcGu2YVuRHwzG5dY6LaUae3h8L0UZPoLbj5JgoDePj75PAbsfH7+8sdwJcfnVmpxbKWfSeFz5J4S6LQxw71VbEdKPNkXqb6UabqfRWb0U0WjGNdkS9677Wiqg2I6rNNtVmEtTVZpt6K6YdJ+EbtVuUohqlaJpCXMNaDYK4QdlalGlRpkmRCMNxLOkxYUeXC8QUaVO0NiXayXJ6K1lyXyGiYhHloE1ojgdF3Ap4UMCDIgQFCEIK5hQtpmARRWIKFlOwNoUwSm60KRBR8DYhMQExgXnSu8QxgyCtLPCYkIjA0xtR8jgGZeKwSBSUicMSUVDGgyKBOSFOSExoMYEn5ycoEIclPCgRh2U8LBIHJYyYQtwk7Nyi5HEzj4nTtsVBEbeQ2IrEQQELChCGaXvT9aAAFkIQ4BZiQYhbmPSazTD3pA53Aosxj8FjLG5D3IY4Su/b4HGayvGxdHZP0rpQxgplrFDBihWCYiVZDwIguVbSve2651gPGaDn75jkUbdk3YMCFMoQVqBQxgtlCMsQljCO1WhxlNYcQVjECyUoVPCwjBcqEJbwOKLdahA3G8RRg7iV3uIYK6bnLZSwsJxcKyhAHOMeE3uMRxF48ldUHFaICmPEhTE8/TsmdjhzaXnef2cV7jnn7hyutzkw3WTfdJMDUw0OHzpAdfoQtVotudXr1Ot1mo06UbNKuX2YSnuK8fgI4/ERJuIpxryaxunxDCjRYhlNKrSoWJMKTcq0KFichkSQBIUlf9oWiKl4jUpcpeSN409YnOP2B8XkFzxM7jGDqJX8skctaKXLHZ3ASwM/uS9CmD4JhMVkPQiTG31CyQIIyl1fn95waDfSWx2iI9BK1y1Ijrfw2L1ZUlu7AVEzvW9Au5mcr1CGQuX4ewsgbkHUTu+72tovkOM2s09D24cFJz5WFvR/PHrbfQrXOxbzOXiBsDie3Erj8Oy3wsUvntfLKdwzpN6K2DfdZP9Uk33TDfZP1Zk+9Bj1Q48RTz8Gtf2Etf0UG/sptQ4x3jrIRHSIZUxxBlNssiNczDRFiwa+ZkxIvbSEZmHJ0ZDujC1aZ4wx7ZFRXEZQGiMojhEUK8kv/NG/rz1d9iQMSkugvARKS6E0kS4vSUOqAmmPKumNlTgaEp1zHD1XGrxhKb111ssQFk8Yl+3LPel1WtB/4Dbv3JPA97RH63Gy3HmiOu5mR3v6p32tqJU+STW6evu939sZT9D/+DhKz1k/9iTSeVI84cm6kLQlanU9YXa+pp7sS3vlR3+uCunP4NEn2a4n2zhKH5eex8odWjVoTUOzCq301qzCkrNO7/E7BQr3RWCq0WbyQJVd+2vs2jfNwQN7aR98mGBqN6Xqo4zV97CktZdV8X7OsCOcwRQX2RFWMEVo/X8RWlak
Gi6nMbaCdvkMvLKJ5vhKDixZTWnpaipLllEuj2GFclc4FpNArSyHygoYW0FQWsK4GeML/JgsqE5gjSozCAssSByYpU/cJSgvnf/rjTCF+zxrtiL27n2E/Y/s4shju6gf2E10+BFsag9hfR+FxkHG4ylWMMWlNs1VM/Ssq8Xl1MpraFdW4mPnEi1ZzeGlq6ksW0Nl2WpsYjWMr4KxlTC+imJpguWj2AsVEUDhPmemjhxi946fcmDXdlqP3kvx4C9YWdvJ+mg3663B+p7ja5SZClfQHFtOXFlJMH4eLF1Na/lqCstXY8vWwdJ1sGwtLDmb8WIl371nEZlTCvdT9NjBQzx03084suun+KPbmTh0H2fVf8l6HuXx6TGxG48EZ7F/bBPblz+N8IzNVM5Yx9LV61hx1kaWrNrAWHkpY0NtiYjkmcL9ZBpTHLj/Tia3f5/mrrs449B2Nse7WJ2Oc7c8ZHdhA3uWXcjkyhdRXvtEVm2+kLPPuZB15XHWDbl8ERldCvduzSrxL77No3d+kXDX91ndeJAzcM4A9rKCh8efwE/OfC7l9Rex+twns2bzBWwulNk87LpFRHoo3A9Nwj99heY9XybY+fcU4gZLvcKPuJCDK57J2KansOWiK3j8485nTaAXKEUkG0Y33Ft1/DMvw+79MgCP+Jl8PXoWk2v+BU++8vk87+JNVIoj/PY4Ecm00Qz3OCL67KsI7/0yH2j/Fl8Pn8Gllz6Vlzx9C084W++9FZHsG71wd8e//J8Jf76Vd7dewtqr38Stl29iojx6D4WI5NfIJZp/50+xOz7EX7V/k7VXv4lXPuPcYZckIjLnMv9ZPKfkrk9g33w3n4+u4NHL36JgF5HcGp2e+71/R7z1tXw3uohvPuFtfOAFFw67IhGReTMa4b7rR0S3vZTt0WY+tPadfOh3nkqgtzWKSI7lP9zbTdqfvJ6HohW8Y+nb+cjLnqG3OIpI7uV+zL29514K9X18KLyeP3/Fc1kxXhp2SSIi8y734f7Yzp8AcOU/v5KNK/W5iiIyGnIf7lO7fkbkxrrzLh52KSIiCyb34W57f85OP5tz164adikiIgsm9+G+5PB9TBY2sUT/gSoiIyTf4d5usLr5EAeXPG7YlYiILKhch3u89z5CYqJVj5/9YBGRHMl1uO9/IHmnzNh6/TeqiIyWXIf7kQd/StsDzjr3ScMuRURkQeU63H3vz3nAz+K8s1cPuxQRkQWV63BfcmgHD4abWD5eHHYpIiILKr/h3m6wqjnJwSXnDbsSEZEFl9tw98f+iZCY9qonDLsUEZEFl9twP/jgzwCo6J0yIjKCchvuhx9I3imzZrPCXURGz0DhbmZXm9m9ZrbDzG7us3+TmX3LzO4ys5+Y2fPnvtRT43vu4QE/i8et02fKiMjomTXczSwEbgGuAS4AbjCzC3oO+0PgNne/FLge+J9zXeipGj+8g53BJlZN6PPbRWT0DNJzvxzY4e73u3sTuBW4rucYB5aly8uB3XNX4mloN1jVmOTgxLmYaTo9ERk9g4T7emBX1/pkuq3bO4AbzWwSuB34/X4nMrObzGybmW3bu3fvaZQ7mM47ZVp6p4yIjKhBwr1f19d71m8APubuG4DnA39tZiec290/6O6Xuftla9asOfVqB3QkfadMeZ1eTBWR0TRIuE8CG7vWN3DisMsrgNsA3P37QAUY2v/8H971s+SdMlt6XxoQERkNg4T7HcD5ZnaOmZVIXjDd2nPMg8CvA5jZE0nCff7GXWYRP7o9eafMWn2mjIiMplnD3d3bwGuArwD3kLwr5m4ze5eZXZse9kbgVWb2j8CngZe5e+/QzYIZP3Qf99tGzlpWHlYJIiJDNdDcc+5+O8kLpd3b3ta1vB24Ym5LO02tOisbD7F/4kq9U0ZERlb+/kN13w4CYlorNfuSiIyu3IX79GTnnTJ6MVVERlfuwv3Qgz/RZ8qIyMjLXbhHj97DTj+b89bqM2VEZHTlLtzHDu7gftvA
+hVjwy5FRGRo8hXurTorG5PsGz+XINA7ZURkdOUr3PfdR0BMc6U+U0ZERluuwr32UPJOmdJavVNGREZbrsL90IPJZ8qs2qRwF5HRlqtwjx7Zzk4/m8etXTnsUkREhipX4T526D52sIFNK8eHXYqIyFDlJ9xbdVbUH2Lf2DkUwvw0S0TkdOQnBdN3ytT1mTIiIvkJ9+bBhwFYunrTkCsRERm+3IT7I/v2A3D2Gn3sgIhIbsJ9euowAGesWDHkSkREhi834R7XpwEojS0ZciUiIsOXm3CPGkm4V8aXDbkSEZHhy024e3MKgLGJpUOuRERk+HIT7jSnaXiR8bHKsCsRERm63IS7N6tUKTNWDIddiojI0OUm3IN2lRplQn2Ou4hIfsLdWlUapiEZERHIUbgX2lXqpqn1REQgT+Ee1WgG6rmLiEDOwr0VqucuIgI5CvdSXKetnruICJCzcI8KmqRDRARyFO5lr9NWuIuIADkK9zHqeFHhLiICeQn3qEWRtsJdRCSVi3DvfGgYCncRESAn4d6oJuFupYkhVyIisjjkItzrnXAvK9xFRCAn4d6oHgIgLGsWJhERyEm4N9Oee1hRz11EBPIS7vUk3IsV9dxFRCAn4d6qdcJdU+yJiEBOwj1Ke+6lcYW7iAgMGO5mdrWZ3WtmO8zs5hmOebGZbTezu83sU3Nb5slFjSTcKxMalhERASjMdoCZhcAtwG8Ak8AdZrbV3bd3HXM+8GbgCnc/YGZnzlfB/cSNaQDK48sW8rIiIovWID33y4Ed7n6/uzeBW4Hreo55FXCLux8AcPc9c1vmyXka7uPj6rmLiMBg4b4e2NW1Pplu6/Z44PFm9j0z+4GZXd3vRGZ2k5ltM7Nte/fuPb2K+2lVqXmJsUpp7s4pIpJhg4S79dnmPesF4HzgKuAG4MNmtuKEL3L/oLtf5u6XrVmz5lRrnVmrSpUypTAXrw+LiPzKBknDSWBj1/oGYHefY77o7i13/yVwL0nYL4igNU3dKpj1ex4SERk9g4T7HcD5ZnaOmZWA64GtPcd8AXgWgJmtJhmmuX8uCz2ZoF2lYeWFupyIyKI3a7i7ext4DfAV4B7gNne/28zeZWbXpod9BdhnZtuBbwF/4O775qvoXmG7RtM0ObaISMesb4UEcPfbgdt7tr2ta9mBN6S3BVeIatQChbuISEcuXoEsxjVaYWXYZYiILBq5CPdSVKMdahYmEZGOXIR72etEoYZlREQ6chPuseZPFRE5KhfhXqFBXFC4i4h0ZD/c202KRFBSuIuIdGQ+3Nv1I8lCUVPsiYh0ZD7ca9Uk3K2kcBcR6ch8uDfScA8rGpYREenIQbgnszAFZX2Wu4hIRw7CPe25lzV/qohIR+bDvfOCamlcY+4iIh2ZD/dWLRmWKVY0LCMi0pH5cI/qSbiXxzUsIyLSkf1wTyfHLo8vG3IlIiKLR+bDPW6m4T6hYRkRkY7Mhztpz31CwzIiIkdlP9xb01S9zFipOOxKREQWjcyHu7Wq1CgTBDbsUkREFo3sh3u7Rs00xZ6ISLfMh3vYrtJQuIuIHCfz4V5oV2kGmmJPRKRb9sM9qtEK1HMXEemW+XAvxnVamhxbROQ4mQ/3UlyjHeqz3EVEumU+3MteJyqo5y4i0i3z4V7xOl5Uz11EpFu2w92dMa8TFxTuIiLdMh3u3q4TmoMmxxYROU6mw70zfyol9dxFRLplOtxr04cBCNRzFxE5TqbDvd4J97LCXUSkW6bDvZlOsRdq/lQRkeNkO9yrRwAoVjRRh4hIt0yHe7uWzMJUGFPPXUSkW6bDvdVIhmVKCncRkeNkOtyjdMy9PKZhGRGRbpkO9zjtuVfG1XMXEemW7XBvVgEYm1g+5EpERBaXgcLdzK42s3vNbIeZ3XyS437bzNzMLpu7Ek+iOU3sxti43ucuItJt1nA3sxC4BbgGuAC4wcwu6HPcUuC1wA/nusgZNatUKVMqhgt2SRGRLBik5345sMPd73f3JnAr
cF2f494NvA+oz2F9J2XtKnUrL9TlREQyY5BwXw/s6lqfTLcdZWaXAhvd/UtzWNuswtY0dTR/qohIr0HC3fps86M7zQLg/cAbZz2R2U1mts3Mtu3du3fwKmcQtGs0As3CJCLSa5BwnwQ2dq1vAHZ3rS8FngR828x2Ak8HtvZ7UdXdP+jul7n7ZWvWrDn9qlOFqEYzUM9dRKTXIOF+B3C+mZ1jZiXgemBrZ6e7H3L31e6+xd23AD8ArnX3bfNScZdiVKOlnruIyAlmDXd3bwOvAb4C3APc5u53m9m7zOza+S7wZIpxnXaocBcR6VUY5CB3vx24vWfb22Y49qpfvazBlOMa7YLCXUSkV6b/Q7XsdSJNji0icoJMh3uFOq5wFxE5QXbD3Z0xb+BFhbuISK/Mhnu7USUwB02OLSJygsyGe62aTI5tCncRkRNkNtzrU8n8qUFZ4S4i0iu74V5VuIuIzCSz4d6sJbMwhRXNwiQi0ivz4V5SuIuInCCz4d6uJy+oFjQ5tojICTIc7tMAlDU5tojICTIb7lEjDXf13EVETpDZcPdGMuZemVC4i4j0ymy4x80qAGMKdxGRE2Q23GlO0/aASlmfLSMi0ivT4V6jTBBmtwkiIvMls8kYtGvUTfOnioj0k+Fwr9JQuIuI9JXZcA/bNYW7iMgMMhvuxahKU5Nji4j0leFwr9EKFO4iIv1kN9zjOm313EVE+spsuJe9TlRQuIuI9JPpcI8L+gcmEZF+MhvuY14nLircRUT6yWS4exwxbg0oaoo9EZF+Mhnu9Wrycb+U1HMXEeknk+Fem05mYdLk2CIi/WUy3OvV5LPcraRwFxHpJ5Ph3qgmPfdQk2OLiPSVyXBv1o4AUFC4i4j0lclwb9WSYZmiwl1EpK9Mhnu7nrxbpjSmcBcR6Seb4Z5Ojl0a1/ypIiL9ZDLc43oS7mWFu4hIX9kM92YyLFOZWDbkSkREFqdMhrs3knAfn1DPXUSkn0yGO60qLQ8pljTNnohIP5kMd2tVqWn+VBGRGQ0U7mZ2tZnda2Y7zOzmPvvfYGbbzewnZvYNM9s896V2Xa9VpU55Pi8hIpJps4a7mYXALcA1wAXADWZ2Qc9hdwGXufvFwGeA9811od3CqEpDPXcRkRkN0nO/HNjh7ve7exO4Fbiu+wB3/5a7V9PVHwAb5rbM4xXaNZqaHFtEZEaDhPt6YFfX+mS6bSavAL78qxQ1m0JUoxmo5y4iMpPCAMdYn23e90CzG4HLgGfOsP8m4CaATZs2DVjiiYpxjUZh+Wl/vYhI3g3Sc58ENnatbwB29x5kZs8B/gtwrbs3+p3I3T/o7pe5+2Vr1qw5nXoBKMV1ooKGZUREZjJIuN8BnG9m55hZCbge2Np9gJldCvwVSbDvmfsyj1eJa0Shwl1EZCazhru7t4HXAF8B7gFuc/e7zexdZnZtetifAEuAvzWzH5vZ1hlONyfK1ImLmj9VRGQmg4y54+63A7f3bHtb1/Jz5riukxrzBq5wFxGZUeb+Q7XVbjNmTShq/lQRkZlkLtyr08kUe2hybBGRGWUu3BtpuAdlhbuIyEwyF+616mEAQoW7iMiMMhfuzWoyC1OoybFFRGaUuXBv1ZJwLyjcRURmlLlwb9aTMXeFu4jIzDIX7lFNk2OLiMwmc+HerqfhPqaeu4jITDIX7nEzmRy7op67iMiMMhfu3gn3CX3kr4jITAb6bJnF5KInXcKh9jUsm9CwjIjITDIX7ksvuQ4uuW72A0VERljmhmVERGR2CncRkRxSuIuI5JDCXUQkhxTuIiI5pHAXEckhhbuISA4p3EVEcsjcfTgXNtsLPDDLYauBxxagnMVG7R4to9puGN22/yrt3uzua2Y7aGjhPggz2+bulw27joWmdo+WUW03jG7bF6LdGpYREckhhbuISA4t9nD/4LALGBK1e7SMarthdNs+7+1e1GPuIiJyehZ7z11ERE6Dwl1EJIcWbbib2dVmdq+Z7TCzm4ddz3wx
s4+a2R4z+1nXtpVm9jUzuy+9P2OYNc4HM9toZt8ys3vM7G4ze126PddtN7OKmf3IzP4xbfc70+3nmNkP03b/jZmVhl3rfDCz0MzuMrMvpeu5b7eZ7TSzn5rZj81sW7pt3n/OF2W4m1kI3AJcA1wA3GBmFwy3qnnzMeDqnm03A99w9/OBb6TredMG3ujuTwSeDvyH9Huc97Y3gGe7+5OBS4CrzezpwH8D3p+2+wDwiiHWOJ9eB9zTtT4q7X6Wu1/S9d72ef85X5ThDlwO7HD3+929CdwK5HJuPXf/e2B/z+brgI+nyx8HXrigRS0Ad3/Y3f8hXT5C8gu/npy33RNT6WoxvTnwbOAz6fbctRvAzDYAvwl8OF03RqDdM5j3n/PFGu7rgV1d65PptlFxlrs/DEkIAmcOuZ55ZWZbgEuBHzICbU+HJn4M7AG+BvwCOOju7fSQvP68/znwn4A4XV/FaLTbga+a2Z1mdlO6bd5/zhfrBNnWZ5ves5lDZrYE+Czwenc/nHTm8s3dI+ASM1sBfB54Yr/DFraq+WVmLwD2uPudZnZVZ3OfQ3PV7tQV7r7bzM4EvmZmP1+Iiy7WnvsksLFrfQOwe0i1DMOjZrYWIL3fM+R65oWZFUmC/ZPu/rl080i0HcDdDwLfJnnNYYWZdTpbefx5vwK41sx2kgyzPpukJ5/3duPuu9P7PSRP5pezAD/nizXc7wDOT19JLwHXA1uHXNNC2gq8NF1+KfDFIdYyL9Lx1o8A97j7n3XtynXbzWxN2mPHzMaA55C83vAt4LfTw3LXbnd/s7tvcPctJL/P33T3l5DzdpvZhJkt7SwDzwV+xgL8nC/a/1A1s+eTPLOHwEfd/b1DLmlemNmngatIPgL0UeDtwBeA24BNwIPAv3L33hddM83MrgS+A/yUY2OwbyEZd89t283sYpIX0EKSztVt7v4uMzuXpEe7ErgLuNHdG8OrdP6kwzJvcvcX5L3dafs+n64WgE+5+3vNbBXz/HO+aMNdRERO32IdlhERkV+Bwl1EJIcU7iIiOaRwFxHJIYW7iEgOKdxFRHJI4S4ikkP/HyTBo/Rxsh8nAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.plot(np.arange(1,niter+1), loss2)\n",
    "plt.title('training loss')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(np.arange(1,niter+1), accuracy_train)\n",
    "plt.plot(np.arange(1,niter+1), accuracy_test)\n",
    "plt.title('validation accuracy')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
